import logging
from typing import List, Optional

import cv2
import numpy as np

from anylabeling.services.auto_labeling.engines.build_onnx_engine import OnnxBaseModel
from anylabeling.services.auto_labeling.utils.sahi.models.base import DetectionModel
from anylabeling.services.auto_labeling.utils.sahi.prediction import ObjectPrediction
from anylabeling.services.auto_labeling.utils.sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
from anylabeling.services.auto_labeling.utils.sahi.utils.import_utils import check_requirements

logger = logging.getLogger(__name__)

class Yolov5ONNX(object):
    """Run a YOLOv5-style ONNX detector on a single image.

    Pipeline: letterbox the image to the network input size, run the ONNX
    session through ``OnnxBaseModel``, keep rows whose confidence exceeds
    ``conf_thres``, apply OpenCV NMS, and map the surviving boxes back to
    original-image pixel coordinates.
    """

    def __init__(
            self,
            model_path: str,
            device: str,
            conf_thres: float,
            nms_thres: float
    ):
        """
        Args:
            model_path: path of the ONNX model, passed to ``OnnxBaseModel``.
            device: device string, passed to ``OnnxBaseModel``.
            conf_thres: rows whose ``column4 * best class score`` is not
                strictly greater than this are dropped; also used as the
                score threshold of ``cv2.dnn.NMSBoxes``.
            nms_thres: IoU threshold handed to ``cv2.dnn.NMSBoxes``.
        """
        self.net = OnnxBaseModel(model_path, device)
        self.conf_thres = conf_thres
        self.nms_thres = nms_thres

    def inference(self, image):
        """Detect objects in ``image``.

        Returns:
            ``(bboxes, scores, class_ids)``: ``bboxes`` is an (N, 4) int array
            of ``x1, y1, x2, y2`` in ``image`` pixels; ``scores`` and
            ``class_ids`` are length-N arrays aligned with ``bboxes``.
        """
        blob, img_size = self.preprocess(image)
        outputs = self.net.get_ort_inference(blob)
        bboxes, scores, class_ids = self.postprocess(outputs, img_size)
        return bboxes, scores, class_ids

    @staticmethod
    def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=False, scaleup=True, stride=32,
                  return_int=False):
        '''Resize and pad image while meeting stride-multiple constraints.

        Args:
            im: image array; ``im.shape[:2]`` is (height, width).
            new_shape: target (height, width), an int, or a 1-element list.
            color: border fill value for ``cv2.copyMakeBorder``.
            auto: if True, pad only up to the next multiple of ``stride``.
            scaleup: if False, the image is only ever shrunk, never enlarged.
            stride: stride used when ``auto`` is True.
            return_int: choose the form of the returned padding.

        Returns:
            ``(padded_image, r, pad)`` where ``r`` is the resize ratio and
            ``pad`` is the per-side float ``(dw, dh)``, or the integer
            ``(left, top)`` border when ``return_int`` is True.
        '''
        shape = im.shape[:2]  # current shape [height, width]
        if isinstance(new_shape, int):
            new_shape = (new_shape, new_shape)
        elif isinstance(new_shape, list) and len(new_shape) == 1:
            new_shape = (new_shape[0], new_shape[0])

        # Scale ratio (new / old)
        r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
        if not scaleup:  # only scale down, do not scale up (for better val mAP)
            r = min(r, 1.0)

        # Compute padding
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding

        if auto:  # minimum rectangle
            dw, dh = np.mod(dw, stride), np.mod(dh, stride)  # wh padding

        dw /= 2  # divide padding into 2 sides
        dh /= 2

        if shape[::-1] != new_unpad:  # resize
            im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
        # The +/-0.1 splits an odd total padding so one side gets the extra pixel.
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
        if not return_int:
            return im, r, (dw, dh)
        else:
            return im, r, (left, top)

    def preprocess(self, input_image):
        """
        Pre-process the input image before feeding it to the network.

        The image is letterboxed to the model's input (height, width), made
        channel-first, cast to float32 and scaled by 1/255 (assumes 0-255
        pixel values -- TODO confirm for non-uint8 inputs). No channel
        reordering is performed.

        Returns:
            ``(blob, (height, width))``: a (1, C, H, W) float32 blob and the
            original image size, needed later to undo the letterbox.
        """
        _, _, input_height, input_width = self.net.get_input_shape()

        image = self.letterbox(input_image, [input_height, input_width], stride=32)[0]
        image = image.transpose((2, 0, 1))  # HWC to CHW
        image = np.ascontiguousarray(image).astype('float32')
        image /= 255  # 0 - 255 to 0.0 - 1.0
        blob = np.expand_dims(image, axis=0)
        return blob, input_image.shape[:2]

    def clip_coords(self, boxes, img_shape):
        """Clip xyxy ``boxes`` in place to ``img_shape`` = (height, width).

        ``ndarray.clip`` returns a new array, so the clipped values must be
        written back into ``boxes`` for the clip to take effect.

        Returns:
            ``boxes`` (the same array, now clipped) for convenience.
        """
        boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(0, img_shape[1])  # x1, x2
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(0, img_shape[0])  # y1, y2
        return boxes

    def scale_coords(self, img1_shape, coords, img0_shape, ratio_pad=None):
        """Map xyxy ``coords`` from letterboxed ``img1_shape`` to ``img0_shape``.

        Both shapes are (height, width). When ``ratio_pad`` is None the gain
        and centred padding are recomputed the same way ``letterbox`` (with
        ``auto=False``) derives them. ``coords`` is modified in place, clipped
        to ``img0_shape`` and returned.
        """
        if ratio_pad is None:  # calculate from img0_shape
            gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1])  # gain  = old / new
            pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2  # wh padding
        else:
            gain = ratio_pad[0][0]
            pad = ratio_pad[1]

        coords[:, [0, 2]] -= pad[0]  # x padding
        coords[:, [1, 3]] -= pad[1]  # y padding
        coords[:, :4] /= gain
        self.clip_coords(coords, img0_shape)
        return coords

    def postprocess(self, outputs, img_size):
        """
        Post-process the network's output, to get the bounding boxes and
        their confidence scores.

        Args:
            outputs: raw model output with a leading batch axis of size 1.
                Each row is read as ``[cx, cy, w, h, col4, class scores...]``
                in network-input pixels (presumably the standard YOLOv5
                layout with column 4 as objectness -- confirm for custom
                exports).
            img_size: (height, width) of the original image.

        Returns:
            ``(boxes, scores, class_ids)``: (N, 4) int xyxy boxes in
            original-image pixels, float scores and integer class ids.
        """
        _, _, input_height, input_width = self.net.get_input_shape()

        pred = np.squeeze(outputs.astype(np.float32), axis=0)

        # Best class per row; confidence = column 4 * best class score.
        class_scores = pred[:, 5:]
        class_ids = np.argmax(class_scores, axis=1)
        confidences = class_scores[np.arange(len(pred)), class_ids] * pred[:, 4]

        keep = confidences > self.conf_thres
        class_ids = class_ids[keep]
        if not keep.any():
            return (np.zeros((0, 4), dtype=int),
                    np.zeros(0, dtype=float),
                    np.zeros(0, dtype=class_ids.dtype))

        # cx/cy/w/h are truncated to int first, then converted to a top-left
        # corner (truncation toward zero, exactly like ``int(...)``).
        cx, cy, w, h = pred[keep, :4].astype(int).T
        x = (cx - w / 2).astype(int)
        y = (cy - h / 2).astype(int)
        # NMSBoxes is fed plain Python lists of [x, y, w, h] and floats.
        boxes = np.stack([x, y, w, h], axis=1).tolist()
        scores = confidences[keep].astype(float).tolist()

        idxs = cv2.dnn.NMSBoxes(boxes, scores, self.conf_thres, self.nms_thres)
        # The result may be empty or (M, 1)-shaped depending on the OpenCV
        # version; normalise to a flat integer index array.
        idxs = np.asarray(idxs, dtype=int).reshape(-1)
        if idxs.size == 0:
            return (np.zeros((0, 4), dtype=int),
                    np.zeros(0, dtype=float),
                    np.zeros(0, dtype=class_ids.dtype))

        xywh = np.asarray(boxes, dtype=float)[idxs]
        xyxy = np.concatenate([xywh[:, :2], xywh[:, :2] + xywh[:, 2:]], axis=1)
        xyxy = self.scale_coords((input_height, input_width), xyxy, img_size)
        out_boxes = xyxy.round().astype(int)

        return out_boxes, np.asarray(scores, dtype=float)[idxs], class_ids[idxs]


class Yolov5OnnxDetectionModel(DetectionModel):
    """SAHI ``DetectionModel`` adapter around :class:`Yolov5ONNX`.

    ``perform_inference`` stores the raw ``(bboxes, scores, class_ids)``
    tuple from ``Yolov5ONNX.inference``; the conversion method turns it into
    ``ObjectPrediction`` objects.
    """

    def check_dependencies(self) -> None:
        """Ensure ``onnxruntime`` is installed."""
        check_requirements(["onnxruntime"])

    def load_model(self):
        """
        Detection model is initialized and set to self.model.

        Reads ``confidence_threshold``, ``nms_threshold``, ``model_path``,
        ``device`` and ``category_mapping`` from the instance (presumably
        populated by the ``DetectionModel`` base constructor -- not visible
        here). ``category_mapping`` is assumed to be a dict keyed by the
        class id as a string.
        """
        self.conf_thres = self.confidence_threshold
        self.nms_thres = self.nms_threshold

        # set model
        self.model = Yolov5ONNX(
            model_path=self.model_path,
            device=self.device,
            conf_thres=self.conf_thres,
            nms_thres=self.nms_thres
        )

        # set category list
        self.category_name_list = list(self.category_mapping.values())
        self.category_name_list_len = len(self.category_name_list)

    def perform_inference(self, image: np.ndarray, image_size: int = None):
        """
        Prediction is performed using self.model and the prediction result is set to self._original_predictions.
        Args:
            image: np.ndarray
                A numpy array that contains the image to be predicted. It is
                passed to the model without any channel reordering (original
                note: 3 channel image should be in BGR order --
                NOTE(review): confirm against the model's training pipeline).
            image_size: int
                Accepted for interface compatibility but unused; the ONNX
                model's own input shape decides the inference size.
        """

        # Confirm model is loaded
        assert self.model is not None, "Model is not loaded, load it by calling .load_model()"

        prediction_result = self.model.inference(image)

        # Wrapped in a list: one entry per image.
        self._original_predictions = [prediction_result]

    @property
    def num_categories(self):
        """Number of categories in ``category_mapping``."""
        return self.category_name_list_len

    @property
    def has_mask(self):
        """This detector produces boxes only, never masks."""
        return False

    @property
    def category_names(self):
        """Category names in ``category_mapping`` value order."""
        return self.category_name_list

    def _create_object_prediction_list_from_original_predictions(
            self,
            shift_amount_list: Optional[List[List[int]]] = None,
            full_shape_list: Optional[List[List[int]]] = None,
    ):
        """
        self._original_predictions is converted to a list of prediction.ObjectPrediction and set to
        self._object_prediction_list_per_image.
        Args:
            shift_amount_list: list of list
                To shift the box and mask predictions from sliced image to full sized image, should
                be in the form of List[[shift_x, shift_y],[shift_x, shift_y],...].
                Defaults to ``[[0, 0]]`` (no shift) when None.
            full_shape_list: list of list
                Size of the full image after shifting, should be in the form of
                List[[height, width],[height, width],...]
        """
        # A None sentinel avoids sharing one mutable default list across calls.
        if shift_amount_list is None:
            shift_amount_list = [[0, 0]]

        original_predictions = self._original_predictions

        # compatibility for sahi v0.8.15
        shift_amount_list = fix_shift_amount_list(shift_amount_list)
        full_shape_list = fix_full_shape_list(full_shape_list)

        # handle all predictions
        object_prediction_list_per_image = []
        for image_ind, original_prediction in enumerate(original_predictions):
            bboxes = original_prediction[0]
            scores = original_prediction[1]
            class_ids = original_prediction[2]

            shift_amount = shift_amount_list[image_ind]
            full_shape = None if full_shape_list is None else full_shape_list[image_ind]
            object_prediction_list = []

            # process predictions
            for original_bbox, score, class_id in zip(bboxes, scores, class_ids):
                bbox = [int(original_bbox[0]), int(original_bbox[1]),
                        int(original_bbox[2]), int(original_bbox[3])]
                category_id = int(class_id)
                category_name = self.category_mapping[str(category_id)]

                # fix negative box coords
                bbox = [max(0, coord) for coord in bbox]

                # fix out of image box coords; full_shape is (height, width)
                if full_shape is not None:
                    bbox[0] = min(full_shape[1], bbox[0])
                    bbox[1] = min(full_shape[0], bbox[1])
                    bbox[2] = min(full_shape[1], bbox[2])
                    bbox[3] = min(full_shape[0], bbox[3])

                # ignore invalid (zero-area or inverted) predictions
                if not (bbox[0] < bbox[2]) or not (bbox[1] < bbox[3]):
                    logger.warning("ignoring invalid prediction with bbox: %s", bbox)
                    continue

                object_prediction = ObjectPrediction(
                    bbox=bbox,
                    category_id=category_id,
                    score=score,
                    bool_mask=None,
                    category_name=category_name,
                    shift_amount=shift_amount,
                    full_shape=full_shape,
                )
                object_prediction_list.append(object_prediction)
            object_prediction_list_per_image.append(object_prediction_list)

        self._object_prediction_list_per_image = object_prediction_list_per_image